import scanpy as sc
import pandas as pd
import logging
from tqdm import tqdm
import torch
from sklearn.preprocessing import LabelEncoder

class XDict(dict):
    """Dictionary of array-like values that records a sample count.

    The sample count is read once, at construction time, from ``shape[0]`` of
    the first value in the dictionary. It is not refreshed when items are
    added, replaced or removed afterwards. An empty ``XDict`` has a sample
    count of 0.
    """

    def __init__(self, *args, **kwargs):
        """Build the dictionary exactly like ``dict`` and cache the sample count.

        Raises:
            AttributeError: If the first value has no ``shape`` attribute.
        """
        super().__init__(*args, **kwargs)
        if self:
            # Only the first value is inspected; the other values are not
            # checked for a matching first dimension.
            self._num = next(iter(self.values())).shape[0]
        else:
            # Nothing to measure: avoid the IndexError an empty dict used to
            # trigger and report zero samples instead.
            self._num = 0

    # No longer required
    # def check(self):
    #     for k, v in self.items():
    #         assert isinstance(v, torch.Tensor), f'{k} is not a torch.Tensor'
    #         assert v.shape[0] == self._num, f'{k} contains {v.shape[0]} samples. Expected: f{self._num}'

    def size(self):
        """Return the sample count cached at construction time (deprecated).

        Every call logs a deprecation notice at INFO level.
        """
        logging.info('Deprecated function: Xdict.size()')
        return self._num

    # Not usable for sparse data
    # def drop(self, ratio):
    #     drop_num = int(self._num * ratio)
    #     keep_idx = np.random.permutation(self._num)[drop_num:]
    #     for k, v in self.items():
    #         self[k] = v[keep_idx]
    #     return self

def clean_batches(data, min_cells=1000, min_counts=5):
    """Drop low-count cells, then drop every batch that is left too small.

    Args:
        data: Annotated data object with a ``'batch'`` column in ``data.obs``.
            It is filtered in place by ``sc.pp.filter_cells`` before the
            batch-level filtering happens.
        min_cells: Batches with fewer than this many remaining cells are
            removed.
        min_counts: Passed to ``sc.pp.filter_cells`` as ``min_counts``.

    Returns:
        ``data`` indexed by the boolean mask of cells whose batch was kept.
    """
    sc.pp.filter_cells(data, min_counts=min_counts)

    # Work on the counts Series directly. The column names produced by
    # ``value_counts().reset_index()`` differ between pandas 1.x and 2.x, so
    # looking them up by name is not portable.
    cells_per_batch = data.obs['batch'].value_counts()
    small_batches = set(cells_per_batch[cells_per_batch < min_cells].index)

    return data[~data.obs['batch'].isin(small_batches)]

import numpy as np

def balanced_partition(data, n_partitions, max_batch_size=6000,
                       first_partition_offset=30000):
    """Greedily distribute batch indices over partitions of similar load.

    Batches are visited from largest to smallest (after capping) and each one
    is assigned to the partition with the smallest accumulated load so far;
    ties go to the lowest partition index.

    Args:
        data: Iterable of sized objects (one per batch); only ``len(batch)``
            is used.
        n_partitions: Number of partitions to create; must be at least 1.
        max_batch_size: Upper bound on the size a single batch contributes to
            a partition's load.
        first_partition_offset: Load that partition 0 starts with, so it
            receives batches only once the other partitions have caught up.
            NOTE(review): the reason for this head start is not visible in
            this file - confirm with the caller.

    Returns:
        A list of ``n_partitions`` lists holding the positions (within
        ``data``) of the batches assigned to each partition.

    Raises:
        ValueError: If ``n_partitions`` is smaller than 1.
    """
    if n_partitions < 1:
        raise ValueError(f'n_partitions must be at least 1, got {n_partitions}')

    # Largest batches first; the sort is stable, so equally sized batches
    # keep their original relative order.
    capped_sizes = sorted(
        ((idx, min(max_batch_size, len(batch))) for idx, batch in enumerate(data)),
        key=lambda item: item[1],
        reverse=True,
    )

    # Initialize partitions
    partitions = [[] for _ in range(n_partitions)]
    loads = [0] * n_partitions
    loads[0] = first_partition_offset

    # Fill partitions
    for idx, size in capped_sizes:
        # First index holding the minimal load.
        target = min(range(n_partitions), key=loads.__getitem__)
        partitions[target].append(idx)
        loads[target] += size

    return partitions

def stratified_sample_genes_by_sparsity(data, boundaries=None, seed=10, n_per_group=25):
    """Sample genes from strata defined by their fraction of zero entries.

    For every column of ``data.to_df()`` the zero rate (share of rows whose
    value is falsy) is computed, the columns are binned by that rate, and up
    to ``n_per_group`` columns are drawn without replacement from each bin.

    Args:
        data: Object exposing ``to_df()`` that returns a cells x genes
            ``pandas.DataFrame``.
        boundaries: Increasing bin edges for the zero rate. Bins are closed on
            the right, and the first bin also includes its left edge.
            Defaults to ``[0, 0.75, 0.9, 0.95, 1]``.
        seed: ``random_state`` used for the draw inside every bin.
        n_per_group: Maximum number of genes drawn from each bin.

    Returns:
        List of sampled column labels, ordered by bin.
    """
    df = data.to_df()
    zero_rates = 1 - df.astype(bool).sum(axis=0) / df.shape[0]
    if boundaries is None:
        # boundaries = [0, zero_rates.mean() - zero_rates.std(), zero_rates.mean(),
        #               min(zero_rates.mean() + zero_rates.std(), 1)]
        boundaries = [0, 0.75, 0.9, 0.95, 1]

    # include_lowest keeps genes whose zero rate equals the first edge (e.g.
    # genes expressed in every cell); otherwise they fall outside every bin
    # and are dropped from the sampling silently.
    gene_group = pd.cut(zero_rates, boundaries, labels=False, include_lowest=True)

    sampled = []
    # groupby skips genes without a bin (NaN) and iterates bins in ascending
    # order.
    for _, rates in zero_rates.groupby(gene_group):
        chosen = rates.sample(min(len(rates), n_per_group), random_state=seed)
        sampled.extend(chosen.index)
    return sampled

def data_setup(adata, return_sparse=True, device='cpu'):
    """Split an annotated data object into per-batch tensors.

    Batches are the label-encoded values of ``adata.obs['batch']`` and are
    processed in encoded order.

    Args:
        adata: Annotated data object. ``adata.obs`` must provide the columns
            ``'batch'``, ``'cell_type'``, ``'Dataset'`` and ``'platform'``,
            plus ``'x_FOV_px'`` / ``'y_FOV_px'`` for batches whose platform is
            ``'cosmx'``. ``adata.X`` is assumed to be a scipy CSR matrix
            (``indptr`` / ``indices`` / ``data`` / ``todense`` are used) -
            TODO confirm against callers.
        return_sparse: If true, expression data is returned as CSR components
            instead of dense tensors.
        device: Torch device the tensors are moved to.

    Returns:
        Tuple ``(seq_list, batch_list, batch_labels, order_list,
        dataset_list, coord_list, label_list)``:

        * ``seq_list``: with ``return_sparse`` four lists (indptr, indices,
          data, shape) holding one tensor per batch; otherwise one dense
          float tensor per batch.
        * ``batch_list``: per batch, the encoded batch label of every cell.
        * ``batch_labels``: encoded batch label of every cell in ``adata``.
        * ``order_list``: per batch, the row positions of its cells in
          ``adata``.
        * ``dataset_list``: per batch, the ``'Dataset'`` value of its first
          cell.
        * ``coord_list``: per batch, the FOV pixel coordinates for ``'cosmx'``
          batches and a tensor filled with -1 otherwise.
        * ``label_list``: per batch, the integer cell-type labels.
    """
    # Data Setup
    order = torch.arange(adata.shape[0], device=device)
    lb = LabelEncoder().fit(adata.obs['batch'])
    batch_labels = lb.transform(adata.obs['batch'])
    # print(lb.classes_)
    seq_list = [[], [], [], []] if return_sparse else []
    batch_list = []
    order_list = []
    dataset_list = []
    coord_list = []
    label_list = []

    # With exactly two distinct cell types the column is cast to int as-is
    # (presumably already 0/1 - verify against callers); otherwise the cell
    # types are label-encoded.
    if adata.obs['cell_type'].nunique() != 2:
        labels = LabelEncoder().fit_transform(adata.obs['cell_type'])
    else:
        labels = adata.obs['cell_type'].astype(int).values
        print(labels.mean())

    for batch in tqdm(range(batch_labels.max() + 1)):
        # Boolean row mask of the current batch, computed once per iteration.
        mask = batch_labels == batch

        if return_sparse:
            x = adata.X[mask].astype(float)
            parts = [torch.from_numpy(arr) for arr in (x.indptr, x.indices, x.data)]
            parts.append(torch.tensor(x.shape))
            for bucket, part in zip(seq_list, parts):
                bucket.append(part.to(device))
        else:
            x = torch.from_numpy(adata.X[mask].todense()).float()
            seq_list.append(x.to(device))
        # x = torch.sparse_csr_tensor(x.indptr, x.indices, x.data, (x.shape[0], x.shape[1])).to_sparse().float()
        # seq_list.append(x)

        order_list.append(order[mask])
        # .iloc[0] takes the first cell of the batch by position; a plain
        # [0] is a label lookup on integer indexes and a deprecated positional
        # fallback otherwise.
        dataset_list.append(adata.obs['Dataset'][mask].iloc[0])
        batch_list.append(torch.from_numpy(batch_labels[mask]).to(device))
        if adata.obs['platform'][mask].iloc[0] == 'cosmx':
            coord_list.append(torch.from_numpy(adata.obs[['x_FOV_px', 'y_FOV_px']][mask].values).to(device))
        else:
            # No coordinates for this platform: fill with -1.
            coord_list.append(torch.zeros(order_list[-1].shape[0], 2).to(device) - 1)
        label_list.append(torch.from_numpy(labels[mask].astype(int)).to(device))

    return seq_list, batch_list, batch_labels, order_list, dataset_list, coord_list, label_list